{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# Colab only: skip this cell when running locally with the dependencies installed.\n",
    "# Fetch the NeMo sources into the current working directory.\n",
    "!git clone https://github.com/NVIDIA/NeMo.git\n",
    "# `!pwd` yields the command output as a list of lines; keep the first line,\n",
    "# i.e. the directory the NeMo checkout was cloned into.\n",
    "NEMO_ROOT = !pwd\n",
    "NEMO_ROOT = NEMO_ROOT[0]\n",
    "# Install NeMo with its \"all\" extras from the checkout, then unidecode.\n",
    "!cd NeMo && pip install \".[all]\"\n",
    "!pip install unidecode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# this is nemo's \"core\" package\n",
    "import nemo\n",
    "# this is nemo's collection of GAN-related modules used for this example\n",
    "import nemo.collections.simple_gan as nemo_simple_gan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide configuration\n",
    "# Directory handed to the MNIST data layer as its `root`\n",
    "data_root = \".\"\n",
    "# Batch size shared by both data layers and the generator\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Instantiate necessary Neural Modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The neural factory; the trainer used for training is taken from it later\n",
    "neural_factory = nemo.core.NeuralModuleFactory()\n",
    "\n",
    "# Shuffled MNIST training data layer reading from `data_root`\n",
    "mnist_data = nemo_simple_gan.MnistGanDataLayer(\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    train=True,\n",
    "    root=data_root,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generator G: called on latents to produce generated images\n",
    "generator = nemo_simple_gan.SimpleGenerator(\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discriminator D: called on an image to produce a decision tensor\n",
    "discriminator = nemo_simple_gan.SimpleDiscriminator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Produces an image lying in between its two inputs (image1, image2)\n",
    "interpolater = nemo_simple_gan.InterpolateImage()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss_D = D(interpolated) - D(real) + lambda * GP\n",
    "# Two loss modules are needed: one for the positive mean(D(.)) term ...\n",
    "disc_loss = nemo_simple_gan.DiscriminatorLoss()\n",
    "# ... and one created with neg=True for the negated - mean(D(.)) terms\n",
    "neg_disc_loss = nemo_simple_gan.DiscriminatorLoss(neg=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient penalty (GP) term of the discriminator loss, with lambda = 10\n",
    "disc_grad_penalty = nemo_simple_gan.GradientPenalty(\n",
    "    lambda_=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Describe how Neural Modules are connected together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create generator DAG\n",
    "# The data layer yields latents and real images; its third output is unused\n",
    "latents, real_data, _ = mnist_data()\n",
    "# G(z): image generated from the latents z\n",
    "generated_image = generator(latents=latents)\n",
    "# D(G(z)): the discriminator's decision on the generated image\n",
    "generator_decision = discriminator(image=generated_image)\n",
    "# loss_G = - mean(D(G(z)))\n",
    "generator_loss = neg_disc_loss(decision=generator_decision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create discriminator DAG\n",
    "# Build an interpolated image that lies somewhere in between the real image\n",
    "# and the generated image\n",
    "# Note: the generated_image tensor from the generator DAG is reused here\n",
    "interpolated_image = interpolater(\n",
    "    image1=real_data,\n",
    "    image2=generated_image,\n",
    ")\n",
    "# D(x~), where x~ is the interpolated image\n",
    "interpolated_decision = discriminator(image=interpolated_image)\n",
    "# D(x), where x is the real image\n",
    "real_decision = discriminator(image=real_data)\n",
    "\n",
    "# Components of the discriminator loss\n",
    "# interpolated_loss = mean(D(x~))\n",
    "interpolated_loss = disc_loss(decision=interpolated_decision)\n",
    "# real_loss = - mean(D(x))\n",
    "real_loss = neg_disc_loss(decision=real_decision)\n",
    "# grad_penalty = mean(lambda * (|gradients| - 1) ** 2)\n",
    "grad_penalty = disc_grad_penalty(\n",
    "    interpolated_image=interpolated_image,\n",
    "    interpolated_decision=interpolated_decision,\n",
    ")\n",
    "# The final loss_D = interpolated_loss + real_loss + grad_penalty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Eval DAG\n",
    "# A separate data layer that samples from the latent distribution\n",
    "random_data = nemo_simple_gan.RandomDataLayer(\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "# NmTensor carrying the latents produced by that data layer\n",
    "latents_e = random_data()\n",
    "# Images generated from the sampled latents\n",
    "generated_image_e = generator(latents=latents_e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab the trainer from the factory (note: `_trainer` is a private attribute)\n",
    "trainer = neural_factory._trainer\n",
    "# Group the losses by the network they are used to train\n",
    "losses_G = [generator_loss]\n",
    "losses_D = [\n",
    "    interpolated_loss,\n",
    "    real_loss,\n",
    "    grad_penalty,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each optimizer must only update one of the two networks, so the\n",
    "# optimizers are created by hand here.\n",
    "# With a single loss and a single optimizer this step can be skipped,\n",
    "# since trainer.train() then creates the optimizer itself.\n",
    "# Optimizer that updates the generator only\n",
    "optimizer_G = trainer.create_optimizer(\n",
    "    optimizer=\"adam\",\n",
    "    things_to_optimize=[generator],\n",
    "    optimizer_params={\"lr\": 1e-4, \"betas\": (0.5, 0.9)},\n",
    ")\n",
    "# Optimizer that updates the discriminator only, same hyperparameters\n",
    "optimizer_D = trainer.create_optimizer(\n",
    "    optimizer=\"adam\",\n",
    "    things_to_optimize=[discriminator],\n",
    "    optimizer_params={\"lr\": 1e-4, \"betas\": (0.5, 0.9)},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper callbacks used to display generated samples during training\n",
    "def save_image(global_vars):\n",
    "    \"\"\"Display the first image stored under ``global_vars[\"image\"]``.\n",
    "\n",
    "    Despite the name, nothing is written to disk: the image is rendered\n",
    "    with matplotlib using the gray colormap.\n",
    "    \"\"\"\n",
    "    first_sample = global_vars[\"image\"][0]\n",
    "    # squeeze(0) drops the leading axis if it has size 1 before conversion\n",
    "    # to a NumPy array (presumably a single channel axis -- TODO confirm)\n",
    "    pixels = first_sample.squeeze(0).detach().cpu().numpy()\n",
    "    plt.imshow(pixels * 255, cmap=\"gray\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def put_tensor_in_dict(tensors, global_vars):\n",
    "    \"\"\"Store the first evaluated ``generated_image_e`` value in ``global_vars[\"image\"]``.\"\"\"\n",
    "    eval_outputs = tensors[generated_image_e.unique_name]\n",
    "    global_vars[\"image\"] = eval_outputs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Define a callback that generates samples\n",
    "# put_tensor_in_dict stores the generated images under global_vars[\"image\"];\n",
    "# save_image reads that entry and displays its first image\n",
    "eval_callback = nemo.core.EvaluatorCallback(\n",
    "    eval_tensors=[generated_image_e],\n",
    "    user_iter_callback=put_tensor_in_dict,\n",
    "    user_epochs_done_callback=save_image,\n",
    "    eval_step=1000,\n",
    ")\n",
    "\n",
    "# Define our training loop. Here we take `num_disc_steps` discriminator\n",
    "# steps prior to taking the generator step\n",
    "num_disc_steps = 3\n",
    "tensors_to_optimize = [\n",
    "    (optimizer_D, losses_D) for _ in range(num_disc_steps)\n",
    "] + [(optimizer_G, losses_G)]\n",
    "\n",
    "# Finally, call train with our training loop and callbacks\n",
    "trainer.train(\n",
    "    tensors_to_optimize=tensors_to_optimize,\n",
    "    callbacks=[eval_callback],\n",
    "    optimization_params={\"num_epochs\": 10},\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
